{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "7423d594-7c2c-4a5b-8089-3faf9e30b116",
          "showTitle": false,
          "title": ""
        },
        "id": "Ozzn44gSRw5F"
      },
      "outputs": [],
      "source": [
        "import base64\n",
        "import io\n",
        "import pandas as pd\n",
        "from PIL import Image\n",
        "import torchvision.transforms as transforms\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "from torch.nn import init"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "71cb328e-20a2-42e7-95d2-ccdf9bcc238a",
          "showTitle": false,
          "title": ""
        },
        "id": "hcu64FfMRw5H"
      },
      "outputs": [],
      "source": [
        "# Ensure every computation happens on the GPU when available\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "3e279fd6-d26a-481e-b428-8629635da103",
          "showTitle": false,
          "title": ""
        },
        "id": "d2LDm8CmRw5I"
      },
      "outputs": [],
      "source": [
        "# The character-level vocabulary is built from the tinyshakespeare text. For the sake of brevity the\n",
        "# decoder is not pretrained on it, although the training function can take both images and text.\n",
        "text_path = \"./input.txt\"\n",
        "with open(text_path, 'r', encoding='utf-8') as f:\n",
        "    text = f.read()\n",
        "\n",
        "# all the unique characters that occur in this text, in sorted order\n",
        "chars = sorted(set(text))\n",
        "# lookup tables between characters and integer ids\n",
        "stoi = {ch: i for i, ch in enumerate(chars)}\n",
        "itos = {i: ch for ch, i in stoi.items()}\n",
        "# The padding token gets the first id after the real characters. Deriving it from len(chars)\n",
        "# rather than a fixed 65 keeps the ids contiguous and collision-free when input.txt does not\n",
        "# contain exactly 65 distinct characters.\n",
        "pad_token_id = len(chars)\n",
        "stoi['<pad>'] = pad_token_id\n",
        "itos[pad_token_id] = '<pad>'\n",
        "\n",
        "\n",
        "def encode(s):\n",
        "    \"\"\"Encoder: take a string, output a list of integers (KeyError for characters not in stoi).\"\"\"\n",
        "    return [stoi[c] for c in s]\n",
        "\n",
        "\n",
        "def decode(l):\n",
        "    \"\"\"Decoder: take a list of integers, output a string.\"\"\"\n",
        "    return ''.join(itos[i] for i in l)\n",
        "\n",
        "\n",
        "# number of ids the model has to handle: every character plus the padding token\n",
        "vocab_size = len(stoi)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "08daa24b-a3e9-46f9-bc57-51d8d1673397",
          "showTitle": false,
          "title": ""
        },
        "id": "TH449G_WRw5I"
      },
      "outputs": [],
      "source": [
        "class PatchEmbeddings(nn.Module):\n",
        "    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):\n",
        "        super().__init__()\n",
        "        self.img_size = img_size\n",
        "        self.patch_size = patch_size\n",
        "        self.num_patches = (img_size // patch_size) ** 2\n",
        "        # Ensure the convolution outputs a feature map with hidden_dim channels\n",
        "        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,\n",
        "                              kernel_size=patch_size, stride=patch_size)\n",
        "\n",
        "    def forward(self, X):\n",
        "        X = self.conv(X)\n",
        "        X = X.flatten(2)  # Flatten the patch dimensions\n",
        "        X = X.transpose(1, 2)  # [B, num_patches, hidden_dim]\n",
        "        return X\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "acec0b6d-4b53-497a-b0a4-8e8aee8270ee",
          "showTitle": false,
          "title": ""
        },
        "id": "_xJdDCXgRw5I"
      },
      "outputs": [],
      "source": [
        "#Position-wise feed-forward network (MLP). The layers are plain nn.Linear, not nn.LazyLinear, so the input width n_embd has to be passed in explicitly\n",
        "\n",
        "class MLP(nn.Module):\n",
        "    def __init__(self, n_embd, dropout=0.1, is_decoder=True):\n",
        "        super().__init__()\n",
        "        layers = [\n",
        "            nn.Linear(n_embd, 4 * n_embd),\n",
        "            nn.ReLU() if is_decoder else nn.GELU(),\n",
        "            nn.Linear(4 * n_embd, n_embd),\n",
        "            nn.Dropout(dropout)\n",
        "        ]\n",
        "        self.net = nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.net(x)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "fff7d282-c2ab-4b90-ab4a-f99fa9cd14c8",
          "showTitle": false,
          "title": ""
        },
        "id": "uqj-YL7pRw5J"
      },
      "outputs": [],
      "source": [
        "class Head(nn.Module):\n",
        "    def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):\n",
        "        super().__init__()\n",
        "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "        self.is_decoder = is_decoder\n",
        "\n",
        "    def forward(self, x):\n",
        "        B, T, C = x.shape\n",
        "        k = self.key(x)\n",
        "        q = self.query(x)\n",
        "        v = self.value(x)\n",
        "\n",
        "        # Compute attention scores\n",
        "        wei = q @ k.transpose(-2, -1) * (C**-0.5)\n",
        "        if self.is_decoder:\n",
        "            # Ensure the mask is the correct size for the current sequence length\n",
        "            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))\n",
        "            wei = wei.masked_fill(tril == 0, float('-inf'))\n",
        "\n",
        "        # Apply softmax to get probabilities\n",
        "        wei = F.softmax(wei, dim=-1)\n",
        "        wei = self.dropout(wei)\n",
        "\n",
        "        # Perform weighted aggregation of values\n",
        "        out = wei @ v\n",
        "        return out\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "49de410a-f93b-4413-a06e-a7487379c5fc",
          "showTitle": false,
          "title": ""
        },
        "id": "-xHZhwnHRw5J"
      },
      "outputs": [],
      "source": [
        "class MultiHeadAttention(nn.Module):\n",
        "    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):\n",
        "        super().__init__()\n",
        "        #Using assert statements for this type of checks is a good idea in general in your code\n",
        "        assert n_embd % num_heads == 0, \"n_embd must be divisible by num_heads\"\n",
        "        self.heads = nn.ModuleList([\n",
        "            Head(n_embd, n_embd // num_heads, dropout, is_decoder)\n",
        "            for _ in range(num_heads)\n",
        "        ])\n",
        "        self.proj = nn.Linear(n_embd, n_embd)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
        "        out = self.dropout(self.proj(out))\n",
        "        return out\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "eb477c79-fd05-4e79-8149-27b62d2b1c7e",
          "showTitle": false,
          "title": ""
        },
        "id": "xmSrzvZYRw5K"
      },
      "outputs": [],
      "source": [
        "class Block(nn.Module):\n",
        "    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):\n",
        "        super().__init__()\n",
        "        self.ln1 = nn.LayerNorm(n_embd)\n",
        "        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)\n",
        "        self.ln2 = nn.LayerNorm(n_embd)\n",
        "        self.ffn = nn.Sequential(\n",
        "            nn.Linear(n_embd, 4 * n_embd),\n",
        "            nn.GELU(),\n",
        "            nn.Linear(4 * n_embd, n_embd),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        original_x = x  # Save for residual connection\n",
        "        x = self.ln1(x)\n",
        "        attn_output = self.attn(x)\n",
        "        x = original_x + attn_output\n",
        "        x = self.ln2(x)\n",
        "        ffn_output = self.ffn(x)\n",
        "        x = x + ffn_output\n",
        "        return x\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "fd6bdd3c-3c0c-413b-8800-a036a7848492",
          "showTitle": false,
          "title": ""
        },
        "id": "bSBuAN1GRw5K"
      },
      "outputs": [],
      "source": [
        "class ViT(nn.Module):\n",
        "    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout):\n",
        "        super().__init__()\n",
        "        self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens)\n",
        "        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))\n",
        "        num_patches = (img_size // patch_size) ** 2\n",
        "        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))\n",
        "        self.dropout = nn.Dropout(emb_dropout)\n",
        "        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])\n",
        "        self.layer_norm = nn.LayerNorm(num_hiddens)\n",
        "\n",
        "    def forward(self, X):\n",
        "        x = self.patch_embedding(X)\n",
        "        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)\n",
        "        x = torch.cat((cls_tokens, x), dim=1)\n",
        "        x += self.pos_embedding\n",
        "        x = self.dropout(x)\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        x = self.layer_norm(x[:, 0])\n",
        "        return x\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "c0733d1d-fb0f-4ab2-879c-9a2fd05a4dc2",
          "showTitle": false,
          "title": ""
        },
        "id": "R3z7SAf-Rw5K"
      },
      "outputs": [],
      "source": [
        "class MultiModalProjector(nn.Module):\n",
        "    def __init__(self, n_embd, image_embed_dim, dropout=0.1):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(image_embed_dim, 4 * image_embed_dim),\n",
        "            nn.GELU(),\n",
        "            nn.Linear(4 * image_embed_dim, n_embd),\n",
        "            nn.Dropout(dropout)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.net(x)\n",
        "        return x\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "db59fdc2-a922-4530-92f2-f51d801b3f39",
          "showTitle": false,
          "title": ""
        },
        "id": "aL9FkWjxRw5K"
      },
      "outputs": [],
      "source": [
        "class DecoderLanguageModel(nn.Module):\n",
        "    def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):\n",
        "        super().__init__()\n",
        "        self.use_images = use_images\n",
        "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
        "        self.position_embedding_table = nn.Embedding(1000, n_embd)\n",
        "        if use_images:\n",
        "            self.image_projection = MultiModalProjector(n_embd, image_embed_dim)\n",
        "        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])\n",
        "        self.ln_f = nn.LayerNorm(n_embd)\n",
        "        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
        "\n",
        "    def forward(self, idx, image_embeds=None, targets=None):\n",
        "        tok_emb = self.token_embedding_table(idx)\n",
        "        if self.use_images and image_embeds is not None:\n",
        "            img_emb = self.image_projection(image_embeds).unsqueeze(1)\n",
        "            tok_emb = torch.cat([img_emb, tok_emb], dim=1)\n",
        "        pos_emb = self.position_embedding_table(torch.arange(tok_emb.size(1), device=device)).unsqueeze(0)\n",
        "        x = tok_emb + pos_emb\n",
        "        x = self.blocks(x)\n",
        "        x = self.ln_f(x)\n",
        "        logits = self.lm_head(x)\n",
        "        if targets is not None:\n",
        "            if self.use_images and image_embeds is not None:\n",
        "                batch_size = idx.size(0)\n",
        "                targets = torch.cat([torch.full((batch_size, 1), -100, dtype=torch.long, device=device), targets], dim=1)\n",
        "            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)\n",
        "            return logits, loss\n",
        "        return logits\n",
        "\n",
        "    def generate(self, idx, image_embeds, max_new_tokens):\n",
        "        B, T = idx.shape\n",
        "        generated = idx\n",
        "\n",
        "        if self.use_images and image_embeds is not None:\n",
        "            img_emb = self.image_projection(image_embeds).unsqueeze(1)\n",
        "            current_output = torch.cat([img_emb, self.token_embedding_table(idx)], dim=1)\n",
        "        else:\n",
        "            current_output = self.token_embedding_table(idx)\n",
        "\n",
        "        for i in range(max_new_tokens):\n",
        "            T_current = current_output.size(1)\n",
        "            current_pos_emb = self.position_embedding_table(torch.arange(T_current, device=device)).unsqueeze(0)\n",
        "            current_output += current_pos_emb\n",
        "\n",
        "            for block in self.blocks:\n",
        "                current_output = block(current_output)\n",
        "\n",
        "            logits = self.lm_head(current_output[:, -1, :])\n",
        "            probs = F.softmax(logits, dim=-1)\n",
        "            idx_next = torch.multinomial(probs, num_samples=1)\n",
        "            generated = torch.cat((generated, idx_next), dim=1)\n",
        "            idx_next_emb = self.token_embedding_table(idx_next)\n",
        "            current_output = torch.cat((current_output, idx_next_emb), dim=1)\n",
        "\n",
        "        return generated\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "6702a660-0512-4152-8d56-97f0ebd562e8",
          "showTitle": false,
          "title": ""
        },
        "id": "FjklXZrPRw5K"
      },
      "outputs": [],
      "source": [
        "class VisionLanguageModel(nn.Module):\n",
        "    def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, num_heads, num_blks, emb_dropout, blk_dropout):\n",
        "        super().__init__()\n",
        "        num_hiddens = image_embed_dim  # Set num_hiddens equal to image_embed_dim\n",
        "        assert num_hiddens % num_heads == 0, \"num_hiddens must be divisible by num_heads\"\n",
        "        self.vision_encoder = ViT(img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout)\n",
        "        self.decoder = DecoderLanguageModel(n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=True)\n",
        "\n",
        "    def forward(self, img_array, idx, targets=None):\n",
        "        image_embeds = self.vision_encoder(img_array)\n",
        "\n",
        "        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:\n",
        "            raise ValueError(\"somethign is messed up with the ViT model. It's returning an empty tensor or the embedding dimension is empty\")\n",
        "\n",
        "        if targets is not None:\n",
        "            logits, loss = self.decoder(idx, image_embeds, targets)\n",
        "            return logits, loss\n",
        "        else:\n",
        "            logits = self.decoder(idx, image_embeds)\n",
        "            return logits\n",
        "\n",
        "    def generate(self, img_array, idx, max_new_tokens):\n",
        "      image_embeds = self.vision_encoder(img_array)\n",
        "\n",
        "      if image_embeds.nelement() == 0 or image_embeds.shape[1] ==0:\n",
        "        raise ValueError(\"somethign is messed up with the ViT model. It's returning an empty tensor or the embedding dimension is empty\")\n",
        "\n",
        "      generated_tokens = self.decoder.generate(idx, image_embeds, max_new_tokens)\n",
        "      return generated_tokens"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "09d59b56-a8e7-4ff4-8f0f-c0edd527f47d",
          "showTitle": false,
          "title": ""
        },
        "id": "mhoM3Rt2Rw5L"
      },
      "outputs": [],
      "source": [
        "def base64_to_tensor(base64_str, img_size=96):\n",
        "    \"\"\"Decode a base64-encoded image into a normalized tensor with a leading batch axis of size 1.\"\"\"\n",
        "    raw_bytes = base64.b64decode(base64_str)\n",
        "    image = Image.open(io.BytesIO(raw_bytes))\n",
        "    # Normalize below is given three channel statistics, so any other mode is converted to RGB\n",
        "    image = image if image.mode == 'RGB' else image.convert('RGB')\n",
        "    steps = [\n",
        "        transforms.Resize((img_size, img_size)),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
        "    ]\n",
        "    pipeline = transforms.Compose(steps)\n",
        "    # unsqueeze(0) adds the batch dimension\n",
        "    return pipeline(image).unsqueeze(0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "45e77024-1b6f-4b25-a946-a808305f42b8",
          "showTitle": false,
          "title": ""
        },
        "id": "qF34w9-ZRw5L"
      },
      "outputs": [],
      "source": [
        "#Adjusting the data loader from makemore for multimodal data\n",
        "def get_batch(df, batch_size, split='train', img_size=96, val_batch_size=8):\n",
        "    \"\"\"Sample one batch of images, input token ids and next-token targets from df.\n",
        "\n",
        "    Args:\n",
        "        df: DataFrame with the columns 'b64string_images' and 'caption'.\n",
        "        batch_size: number of rows drawn for the 'train' split.\n",
        "        split: 'train' uses the first 90% of the rows, anything else the remaining rows.\n",
        "        img_size: side length the images are resized to.\n",
        "        val_batch_size: number of rows drawn for a non-train split.\n",
        "\n",
        "    Returns:\n",
        "        (images, padded_text, targets), all on `device`. padded_text holds the encoded\n",
        "        captions right-padded with the '<pad>' id; targets is padded_text shifted left by\n",
        "        one position with a '<pad>' column appended, so both have the same shape.\n",
        "    \"\"\"\n",
        "    # first 90% of the rows are train, the rest val\n",
        "    n = int(0.9 * len(df))\n",
        "    is_train = split == 'train'\n",
        "    data = df.iloc[:n] if is_train else df.iloc[n:]\n",
        "    batch_size = batch_size if is_train else val_batch_size\n",
        "    # train batches are drawn without replacement, validation batches with replacement\n",
        "    batch = data.sample(n=batch_size, replace=not is_train)\n",
        "\n",
        "    images = torch.cat([base64_to_tensor(img, img_size) for img in batch['b64string_images']], dim=0).to(device)\n",
        "    text_indices = [torch.tensor(encode(desc), dtype=torch.long) for desc in batch['caption']]\n",
        "    max_length = max(len(t) for t in text_indices)\n",
        "\n",
        "    pad_id = stoi['<pad>']\n",
        "    padded_text = torch.full((batch_size, max_length), fill_value=pad_id, dtype=torch.long, device=device)\n",
        "    for row, tokens in enumerate(text_indices):\n",
        "        padded_text[row, :len(tokens)] = tokens\n",
        "\n",
        "    # Dropping the first column and appending one pad column keeps the width at max_length,\n",
        "    # so targets always matches padded_text and needs no further truncation or padding\n",
        "    pad_column = torch.full((batch_size, 1), fill_value=pad_id, dtype=torch.long, device=device)\n",
        "    targets = torch.cat([padded_text[:, 1:], pad_column], dim=1)\n",
        "\n",
        "    return images, padded_text, targets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "605036de-0c47-4dc8-9328-1a8ad9c889f8",
          "showTitle": false,
          "title": ""
        },
        "id": "IzBrV3ZWRw5L"
      },
      "outputs": [],
      "source": [
        "#Adjusting the training loop from makemore for multimodal data\n",
        "def train_model(model, df, epochs, vocab_size, img_size=96):\n",
        "    \"\"\"Train model on batches sampled from df with Adam, printing training and validation loss.\n",
        "\n",
        "    Reads the notebook-level settings learning_rate, max_iters, batch_size, eval_interval\n",
        "    and device. vocab_size is kept in the signature for existing callers and is not used.\n",
        "\n",
        "    Args:\n",
        "        model: module called as model(images, idx, targets) and returning (logits, loss).\n",
        "        df: DataFrame passed on to get_batch and estimate_loss.\n",
        "        epochs: number of passes, each consisting of max_iters optimizer steps.\n",
        "        img_size: side length the images are resized to.\n",
        "    \"\"\"\n",
        "    # Move the parameters to the target device before the optimizer is constructed from them\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for step in range(max_iters):\n",
        "            images, idx, targets = get_batch(df, batch_size, 'train', img_size)\n",
        "            optimizer.zero_grad()\n",
        "            _, loss = model(images, idx, targets)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            if step % eval_interval == 0:\n",
        "                print(f\"Loss at iteration {step}: {loss.item()}\")\n",
        "        val_loss = estimate_loss(model, df, 'val', img_size, val_batch_size=8)\n",
        "        print(f\"Validation Loss after epoch {epoch}: {val_loss}\")\n",
        "\n",
        "def estimate_loss(model, df, split, img_size=96, val_batch_size=8):\n",
        "    \"\"\"Return the mean loss of model over eval_iters batches sampled from the given split.\n",
        "\n",
        "    Reads the notebook-level settings eval_iters and batch_size. The model is switched to\n",
        "    eval mode and left there; the batches are evaluated without tracking gradients.\n",
        "    \"\"\"\n",
        "    losses = []\n",
        "    model.eval()\n",
        "    # No backward pass happens here, so autograd bookkeeping would only cost memory and time\n",
        "    with torch.no_grad():\n",
        "        for _ in range(eval_iters):\n",
        "            images, idx, targets = get_batch(df, batch_size, split, img_size, val_batch_size=val_batch_size)\n",
        "            _, loss = model(images, idx, targets)\n",
        "            losses.append(loss.item())\n",
        "    return sum(losses) / len(losses)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "1670d54d-791d-4d03-945a-12aa634af044",
          "showTitle": false,
          "title": ""
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ITWGvRmkRw5L",
        "outputId": "b5aac3ec-6c6a-442a-f7df-84710c1074fc"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(90, 2)"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ],
      "source": [
        "# Load the image/caption pairs\n",
        "df = pd.read_csv(\"./inputs.csv\")\n",
        "# Repeat the rows 30 times so that there is enough data to test with. This only duplicates\n",
        "# data; a real dataset would have more rows. Only the two columns used for training are kept.\n",
        "df = pd.concat([df for _ in range(30)]).loc[:, ['b64string_images', 'caption']]\n",
        "df.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "3d948df0-4979-46e1-8208-629e8a95c93b",
          "showTitle": false,
          "title": ""
        },
        "id": "MFM46Cs0Rw5L"
      },
      "outputs": [],
      "source": [
        "batch_size = 16 # how many independent sequences will we process in parallel?\n",
        "block_size = 32 # what is the maximum context length for predictions?\n",
        "max_iters = 100\n",
        "eval_interval = 10\n",
        "learning_rate = 1e-3\n",
        "epochs=1\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "eval_iters = 40\n",
        "num_blks= 3\n",
        "head_size = 16\n",
        "n_embd = 128\n",
        "n_head = 8\n",
        "n_layer = 8\n",
        "dropout = 0.1\n",
        "img_size=96\n",
        "patch_size =16\n",
        "image_embed_dim = 512\n",
        "emb_dropout = blk_dropout =0.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "04e27881-aaee-4f40-b5cd-4e6a2de0fca7",
          "showTitle": false,
          "title": ""
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ayLj98T7Rw5L",
        "outputId": "0c541a39-c9e2-428e-f205-4d3608685722"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loss at iteration 0: 4.551465034484863\n",
            "Loss at iteration 10: 0.541579008102417\n",
            "Loss at iteration 20: 0.1060851439833641\n",
            "Loss at iteration 30: 0.035827286541461945\n",
            "Loss at iteration 40: 0.03595846891403198\n"
          ]
        }
      ],
      "source": [
        "# Build the vision-language model and move it to the selected device\n",
        "model = VisionLanguageModel(\n",
        "    n_embd=n_embd,\n",
        "    image_embed_dim=image_embed_dim,\n",
        "    vocab_size=vocab_size,\n",
        "    n_layer=n_layer,\n",
        "    img_size=img_size,\n",
        "    patch_size=patch_size,\n",
        "    num_heads=n_head,\n",
        "    num_blks=num_blks,\n",
        "    emb_dropout=emb_dropout,\n",
        "    blk_dropout=blk_dropout,\n",
        ").to(device)\n",
        "\n",
        "# One forward pass on random inputs: a single image and block_size random token ids\n",
        "dummy_img = torch.randn(1, 3, img_size, img_size).to(device)\n",
        "dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(device)\n",
        "model(dummy_img, dummy_idx)\n",
        "\n",
        "# Train the model\n",
        "train_model(model, df, epochs, vocab_size, img_size)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "beff4f71-70ae-45f2-b51b-ce8e3e3085f2",
          "showTitle": false,
          "title": ""
        },
        "id": "sslq5qWDRw5L"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "application/vnd.databricks.v1+notebook": {
      "dashboards": [],
      "language": "python",
      "notebookMetadata": {
        "pythonIndentUnit": 2
      },
      "notebookName": "seemore",
      "widgets": {}
    },
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "language_info": {
      "name": "python"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}